import logging
import os.path as osp
import time
from time import strftime, localtime

import numpy as np
import torch
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter

from .common import is_list, is_tensor, ts2np, mkdir, Odict, NoOp


class MessageManager:
    """Collect training messages and route them to TensorBoard and a logger.

    Construction only sets up the in-memory state. ``init_manager`` must be
    called before any method that uses ``self.writer`` or ``self.logger``.
    """

    def __init__(self):
        # Values accumulated via ``append`` until the next ``flush``.
        self.info_dict = Odict()
        self.writer_hparams = ['image', 'scalar']
        # Start of the timing window reported by ``log_training_info``.
        self.time = time.time()

    def init_manager(self, save_path, log_to_file, log_iter, iteration=0):
        """Create the TensorBoard writer and configure the logger.

        Args:
            save_path: Directory under which ``summary/`` (and ``logs/`` when
                ``log_to_file`` is true) is created.
            log_to_file: Whether log records are also written to a file.
            log_iter: Stored as ``self.log_iter``; not read by this class.
            iteration: Initial value of ``self.iteration``; also handed to
                ``SummaryWriter`` as ``purge_step``.
        """
        self.iteration = iteration
        self.log_iter = log_iter
        summary_dir = osp.join(save_path, "summary/")
        mkdir(summary_dir)
        self.writer = SummaryWriter(summary_dir, purge_step=self.iteration)
        self.init_logger(save_path, log_to_file)

    def init_logger(self, save_path, log_to_file):
        """Configure the ``'openstereo'`` logger with console/file handlers.

        The logger is process-global, so handlers left over from an earlier
        call are removed first; otherwise every record would be emitted once
        per previous initialisation.

        Args:
            save_path: Directory under which ``logs/`` is created when file
                logging is enabled.
            log_to_file: If true, also log to
                ``<save_path>/logs/<timestamp>.txt``.
        """
        self.logger = logging.getLogger('openstereo')
        self.logger.setLevel(logging.INFO)
        self.logger.propagate = False

        # Drop (and close) stale handlers so re-initialisation does not
        # duplicate output or leak open log files.
        for stale in list(self.logger.handlers):
            self.logger.removeHandler(stale)
            stale.close()

        formatter = logging.Formatter(
            fmt='[%(asctime)s] [%(levelname)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
        if log_to_file:
            mkdir(osp.join(save_path, "logs/"))
            vlog = logging.FileHandler(
                osp.join(save_path, "logs/", strftime('%Y-%m-%d-%H-%M-%S', localtime()) + '.txt'))
            vlog.setLevel(logging.INFO)
            vlog.setFormatter(formatter)
            self.logger.addHandler(vlog)

        # NOTE: the handler accepts DEBUG, but the logger itself is set to
        # INFO above, so ``log_debug`` records are filtered out before they
        # reach any handler.
        console = logging.StreamHandler()
        console.setFormatter(formatter)
        console.setLevel(logging.DEBUG)
        self.logger.addHandler(console)

    def append(self, info):
        """Normalise ``info`` values to lists and add them to ``info_dict``.

        Each value is wrapped in a list if it is not one already, and tensor
        elements are converted with ``ts2np``. ``info`` is modified in place
        before being handed to ``self.info_dict.append``.

        Args:
            info: Mapping from message key to a value or list of values.
        """
        for k, v in info.items():
            v = v if is_list(v) else [v]
            # Re-assigning an existing key is safe while iterating.
            info[k] = [ts2np(item) if is_tensor(item) else item for item in v]
        self.info_dict.append(info)

    def flush(self):
        """Discard the accumulated info and flush the TensorBoard writer."""
        self.info_dict.clear()
        self.writer.flush()

    def train_step(self, summary):
        """Write ``summary`` at the current iteration, then advance it by one."""
        self.write_to_tensorboard(summary, self.iteration)
        self.iteration += 1

    def write_to_tensorboard(self, summary, iteration=None):
        """Dispatch each summary entry to the matching ``SummaryWriter`` method.

        Keys have the form ``"<module>/<tag>"``; the entry is written with
        ``self.writer.add_<module>(tag, value, step)``. Only the leading
        ``"<module>/"`` prefix is removed, so later occurrences of the same
        text inside the tag are preserved. A key without a slash is used
        unchanged as the tag.

        Args:
            summary: Mapping from key to the value to write.
            iteration: Step to record; defaults to ``self.iteration``.

        Raises:
            AttributeError: If the writer has no ``add_<module>`` method.
        """
        step = self.iteration if iteration is None else iteration
        for k, v in summary.items():
            module_name, sep, rest = k.partition('/')
            board_name = rest if sep else k
            writer_module = getattr(self.writer, 'add_' + module_name)
            if is_tensor(v):
                v = v.detach()
            if 'image' in module_name:
                # Image modules receive a single normalised grid.
                v = vutils.make_grid(v, normalize=True, scale_each=True)
            writer_module(board_name, v, step)

    def log_training_info(self):
        """Log iteration, elapsed time and the mean of every scalar entry.

        Only ``info_dict`` keys containing ``'scalar'`` are reported. The
        timing window is restarted afterwards.
        """
        elapsed = time.time() - self.time
        parts = ["Iteration {:0>5}, Cost {:.2f}s".format(self.iteration, elapsed)]
        for k, v in self.info_dict.items():
            if 'scalar' not in k:
                continue
            label = k.replace('scalar/', '').replace('/', '_')
            parts.append("{0}={1:.5f}".format(label, np.mean(v)))
        self.log_info(", ".join(parts))
        self.reset_time()

    def reset_time(self):
        """Restart the timing window used by ``log_training_info``."""
        self.time = time.time()

    def log_debug(self, *args, **kwargs):
        """Forward to ``self.logger.debug``."""
        self.logger.debug(*args, **kwargs)

    def log_info(self, *args, **kwargs):
        """Forward to ``self.logger.info``."""
        self.logger.info(*args, **kwargs)

    def log_warning(self, *args, **kwargs):
        """Forward to ``self.logger.warning``."""
        self.logger.warning(*args, **kwargs)


# Module-level instances handed out by get_msg_mgr().
msg_mgr = MessageManager()
# Returned by get_msg_mgr() on distributed ranks > 0; presumably absorbs
# calls without acting on them (NoOp is defined in .common — confirm there).
noop = NoOp()


def get_msg_mgr():
    """Return the message manager appropriate for the current process.

    When ``torch.distributed`` is initialised and this process has a rank
    greater than zero, the shared ``noop`` object is returned; in every other
    case the shared ``msg_mgr`` instance is returned.
    """
    dist = torch.distributed
    # Short-circuit keeps get_rank() from being called before initialisation.
    if dist.is_initialized() and dist.get_rank() > 0:
        return noop
    return msg_mgr
